{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZVJCNxUcVkkm"
   },
   "source": [
    "# Explicit sharding (a.k.a. \"sharding in types\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ATLBMlw3VcCJ"
   },
   "source": [
    "JAX's traditional automatic sharding leaves sharding decisions to the compiler.\n",
    "You can provide hints to the compiler using\n",
    "`jax.lax.with_sharding_constraint` but for the most part you're supposed to be\n",
    "focussed on the math while the compiler worries about sharding.\n",
    "\n",
    "But what if you have a strong opinion about how you want your program sharded?\n",
    "With enough calls to `with_sharding_constraint` you can probably guide the\n",
    "compiler's hand to make it do what you want. But \"compiler tickling\" is\n",
    "famously not a fun programming model. Where should you put the sharding\n",
    "constraints? You could put them on every single intermediate but that's a lot\n",
    "of work and it's also easy to make mistakes that way because there's no way to\n",
    "check that the shardings make sense together. More commonly, people add just\n",
    "enough sharding annotations to constrain the compiler. But this is a slow\n",
    "iterative process. It's hard to know ahead of time what XLA's GSPMD pass will\n",
    "do (it's a whole-program optimization) so all you can do is add annotations,\n",
    "inspect XLA's sharding choices to see what happened, and repeat.\n",
    "\n",
    "To fix this we've come up with a different style of sharding programming we\n",
    "call \"explicit sharding\" or \"sharding in types\". The idea is that sharding\n",
    "propagation happens at the JAX level at trace time. Each JAX operation has a\n",
    "sharding rule that takes the shardings of the op's arguments and produces a\n",
    "sharding for the op's result. For most operations these rules are simple and\n",
    "obvious because there's only one reasonable choice. But for some operations it's\n",
    "unclear how to shard the result. In that case we ask the programmer\n",
    "to provide an `out_sharding` argument explicitly and we throw a (trace-time)\n",
    "error otherwise. Since the shardings are propagated at trace time they can\n",
    "also be _queried_ at trace time too. In the rest of this doc we'll describe\n",
    "how to use explicit sharding mode. Note that this is a new feature so we\n",
    "expect there to be bugs and unimplemented cases. Please let us know when you\n",
    "find something that doesn't work! Also see {doc}`../the-training-cookbook`\n",
    "for a real-world machine learning training example that uses explicit sharding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "hVi6mApuVw3r"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "from jax.sharding import PartitionSpec as P, AxisType, get_abstract_mesh, reshard\n",
    "\n",
    "jax.config.update('jax_num_cpu_devices', 8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oU5O6yOLWqbP"
   },
   "source": [
    "## Setting up an explicit mesh\n",
    "\n",
    "The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that\n",
    "the JAX-level _type_ of a value includes a description of how the value is sharded.\n",
    "We can query the JAX-level type of any JAX value (or Numpy array, or Python\n",
    "scalar) using `jax.typeof`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mzDIDvj7Vw0k",
    "outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JAX-level type of some_array: ShapedArray(int32[8])\n"
     ]
    }
   ],
   "source": [
    "some_array = np.arange(8)\n",
    "print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TZzp_1sXW061"
   },
   "source": [
    "Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n",
    "is almost _defined_ as \"the information about a value we have access to while\n",
    "under a jit)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "IyPx_-IBVwxr",
    "outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JAX-level type of x during tracing: ShapedArray(int32[8])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@jax.jit\n",
    "def foo(x):\n",
    "  print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n",
    "  return x + x\n",
    "\n",
    "foo(some_array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c3gNPzfZW45K"
   },
   "source": [
    "These types show the shape and dtype of array but they don't appear to\n",
    "show sharding. (Actually, they _did_ show sharding, but the shardings were\n",
    "trivial. See \"Concrete array shardings\", below.) To start seeing some\n",
    "interesting shardings we need to set up an explicit-sharding mesh.\n",
    "\n",
    "`jax.set_mesh` can be used as a global setter or a context manager. We use\n",
    "`jax.set_mesh` in this notebook as a global setter. You can use it as a scoped\n",
    "context manager via `with jax.set_mesh(mesh)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NO2ulM_QW7a8",
    "outputId": "d888371b-080e-4bff-be5d-ea56beda3aac"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n"
     ]
    }
   ],
   "source": [
    "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n",
    "                     axis_types=(AxisType.Explicit, AxisType.Explicit))\n",
    "jax.set_mesh(mesh)\n",
    "\n",
    "print(f\"Current mesh is: {get_abstract_mesh()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "V7bVz6tzW_Eb"
   },
   "source": [
    "Now we can create some sharded arrays using `reshard`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1-TzmA0AXCAf",
    "outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "replicated_array type: ShapedArray(int32[4,2])\n",
      "sharded_array type: ShapedArray(int32[4@X,2])\n"
     ]
    }
   ],
   "source": [
    "replicated_array = np.arange(8).reshape(4, 2)\n",
    "sharded_array = reshard(replicated_array, P(\"X\", None))\n",
    "\n",
    "print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n",
    "print(f\"sharded_array type: {jax.typeof(sharded_array)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B0jBBXtgXBxr"
   },
   "source": [
    "We should read the type `f32[4@X, 2]` as \"a 4-by-2 array of 32-bit floats whose first dimension\n",
    "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n",
    "axes\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N8yMauHAXKtX"
   },
   "source": [
    "These shardings associated with JAX-level types propagate through operations. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Gy7ABds3XND3",
    "outputId": "0d72dad2-381a-4e96-f771-40d705da1376"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "arg0 sharding: ShapedArray(int32[4@X,1])\n",
      "arg1 sharding: ShapedArray(int32[1,8@Y])\n",
      "result sharding: ShapedArray(int32[4@X,8@Y])\n"
     ]
    }
   ],
   "source": [
    "arg0 = reshard(np.arange(4).reshape(4, 1), P(\"X\", None))\n",
    "arg1 = reshard(np.arange(8).reshape(1, 8), P(None, \"Y\"))\n",
    "\n",
    "result = arg0 + arg1\n",
    "\n",
    "print(f\"arg0 sharding: {jax.typeof(arg0)}\")\n",
    "print(f\"arg1 sharding: {jax.typeof(arg1)}\")\n",
    "print(f\"result sharding: {jax.typeof(result)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lwsygUmVXPCk"
   },
   "source": [
    "We can do the same type querying under a jit:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "grCcotr-XQjY",
    "outputId": "c2db656c-809f-49a6-c948-629d6420360c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x sharding: ShapedArray(int32[4@X,1])\n",
      "y sharding: ShapedArray(int32[1,8@Y])\n",
      "ans sharding: ShapedArray(int32[4@X,8@Y])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Array([[ 0,  1,  2,  3,  4,  5,  6,  7],\n",
       "       [ 1,  2,  3,  4,  5,  6,  7,  8],\n",
       "       [ 2,  3,  4,  5,  6,  7,  8,  9],\n",
       "       [ 3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@jax.jit\n",
    "def add_arrays(x, y):\n",
    "  ans = x + y\n",
    "  print(f\"x sharding: {jax.typeof(x)}\")\n",
    "  print(f\"y sharding: {jax.typeof(y)}\")\n",
    "  print(f\"ans sharding: {jax.typeof(ans)}\")\n",
    "  return ans\n",
    "\n",
    "add_arrays(arg0, arg1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lVd6a5ufXZoH"
   },
   "source": [
    "That's the gist of it. Shardings propagate deterministically at trace time and\n",
    "we can query them at trace time."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ETtwK3LCXSkd"
   },
   "source": [
    "## Sharding rules and operations with ambiguous sharding\n",
    "\n",
    "Each op has a sharding rule which specifies its output sharding given its\n",
    "input shardings. A sharding rule may also throw a (trace-time) error. Each op\n",
    "is free to implement whatever sharding rule it likes, but the usual pattern is\n",
    "the following: For each output axis we identify zero of more corresponding\n",
    "input axes. The output axis is then\n",
    "sharded according to the “consensus” sharding of the corresponding input axes. i.e., it's\n",
    "`None` if the input shardings are all `None`, and it's the common non-None input sharding\n",
    "if there’s exactly one of them, or an error (requiring an explicit out_sharding=... kwarg) otherwise."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "an8-Fq1uXehp"
   },
   "source": [
    "This procedure is done on an axis-by-axis basis. When it’s done, we might end\n",
    "up with an array sharding that mentions a mesh axis more than once, which is\n",
    "illegal. In that case we raise a (trace-time) sharding error and ask for an\n",
    "explicit out_sharding.\n",
    "\n",
    "Here are some example sharding rules:\n",
    "   * nullary ops like `jnp.zeros`, `jnp.arange`: These ops create arrays out of whole\n",
    "     cloth so they don’t have input shardings to propagate. Their output is\n",
    "     unsharded by default unless overridden by the out_sharding kwarg.\n",
    "   * unary elementwise ops like `sin`, `exp`: The output is sharded the same as the\n",
    "     input.\n",
    "   * binary ops (`+`, `-`, `*` etc.): Axis shardings of “zipped” dimensions\n",
    "     must match (or be `None`). “Outer product” dimensions (dimensions that\n",
    "     appear in only one argument) are sharded as they are in the input. If the\n",
    "     result ends up mentioning a mesh axis more than once it's an error.\n",
    "   * `reshape.` Reshape is a particularly tricky op. An output axis can map to more\n",
    "     than one input axis (when reshape is used to merge axes) or just a part\n",
    "     of an input axis (when reshape is used to split axes). Our usual rules\n",
    "     don’t apply. Instead we treat reshape as follows. We strip away singleton\n",
    "     axes (these can’t be sharded anyway. Then\n",
    "     we decide whether the reshape is a “split” (splitting a single axis into\n",
    "     two or more adjacent axes), a “merge” (merging two or more adjacent axes\n",
    "     into a single one) or something else. If we have a split or merge case in\n",
    "     which the split/merged axes are sharded as None then we shard the\n",
    "     resulting split/merged axes as None and the other axes according to their\n",
    "     corresponding input axis shardings. In all other cases we throw an error\n",
    "     and require the user to provide an `out_sharding` argument."
   ]
  },
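  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the per-axis “consensus” procedure described above, here is a toy\n",
    "pure-Python model. It is only a sketch: the helper names `consensus` and `output_spec` are\n",
    "hypothetical and are not part of JAX, and the real sharding rules are implemented per op inside\n",
    "JAX. Mesh axes are modelled as plain strings and an unsharded array axis as `None`.\n",
    "\n",
    "```python\n",
    "# Toy model of the per-axis consensus rule; not JAX's actual implementation.\n",
    "def consensus(axis_shardings):\n",
    "    # Combine the shardings of the input axes that map to one output axis.\n",
    "    named = {s for s in axis_shardings if s is not None}\n",
    "    if not named:\n",
    "        return None          # every corresponding input axis is unsharded\n",
    "    if len(named) == 1:\n",
    "        return named.pop()   # exactly one distinct mesh axis: propagate it\n",
    "    raise ValueError(f'conflicting shardings {sorted(named)}; pass out_sharding')\n",
    "\n",
    "\n",
    "def output_spec(per_axis_inputs):\n",
    "    # per_axis_inputs[i] lists the input-axis shardings feeding output axis i.\n",
    "    spec = tuple(consensus(group) for group in per_axis_inputs)\n",
    "    used = [s for s in spec if s is not None]\n",
    "    if len(used) != len(set(used)):\n",
    "        raise ValueError(f'mesh axis used more than once in {spec}; pass out_sharding')\n",
    "    return spec\n",
    "\n",
    "\n",
    "# [4@X,1] + [1,8@Y]: each output axis sees one sharded and one unsharded input axis.\n",
    "print(output_spec([('X', None), (None, 'Y')]))\n",
    "\n",
    "# [4@X,4] + [4,4@X]: per-axis consensus succeeds, but 'X' would appear twice.\n",
    "try:\n",
    "    output_spec([('X', None), (None, 'X')])\n",
    "except ValueError as err:\n",
    "    print('error:', err)\n",
    "```\n",
    "\n",
    "The first call mirrors the `arg0 + arg1` example earlier in this notebook. The second call\n",
    "models the case where the axis-by-axis result would mention mesh axis `X` twice, which is the\n",
    "situation that requires an explicit `out_sharding`."
   ]
  },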
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jZMp6w48Xmd7"
   },
   "source": [
    "## JAX transformations and higher-order functions\n",
    "\n",
    "The staged-out representation of JAX programs is explicitly typed. (We call\n",
    "the types “avals” but that’s not important.) In explicit-sharding mode, the\n",
    "sharding is part of that type. This means that shardings need to match\n",
    "wherever types need to match. For example, the two sides of a `lax.cond` need to\n",
    "have results with matching shardings. And the carry of `lax.scan` needs to have the\n",
    "same sharding at the input and the output of the scan body. And when you\n",
    "construct a jaxpr without concrete arguments using `make_jaxpr` you need to\n",
    "provide shardings too. Certain JAX transformations perform type-level\n",
    "operations. Automatic differentation constructs a tangent type for each primal\n",
    "type in the original computation (e.g. `TangentOf(float) == float`,\n",
    "`TangentOf(int) == float0`). With sharding in the types, this means that tangent\n",
    "values are sharded in the same way as their primal values. Vmap and scan also\n",
    "do type-level operations, they lift an array shape to a rank-augmented version\n",
    "of that shape. That extra array axis needs a sharding. We can infer it from the\n",
    "arguments to the vmap/scan but they all need to agree. And a nullary vmap/scan\n",
    "needs an explicit sharding argument just as it needs an explicit length\n",
    "argument."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ERJx4p0tXoS3"
   },
   "source": [
    "## Working around unimplemented sharding rules using `auto_axes`\n",
    "\n",
    "The implementation of explicit sharding is still a work-in-progress and there\n",
    "are plenty of ops that are missing sharding rules. For example, `scatter` and\n",
    "`gather` (i.e. indexing ops).\n",
    "\n",
    "Normally we wouldn't suggest using a feature with so many unimplemented cases,\n",
    "but in this instance there's a reasonable fallback you can use: `auto_axes`.\n",
    "The idea is that you can temporarily drop into a context where the mesh axes\n",
    "are \"auto\" rather than \"explicit\". You explicitly specify how you intend the\n",
    "final result of the `auto_axes` to be sharded as it gets returned to the calling context.\n",
    "\n",
    "This works as a fallback for ops with unimplemented sharding rules. It also\n",
    "works when you want to override the sharding-in-types type system. For\n",
    "example, suppose we want to add a `f32[4@X, 4]` to a `f32[4, 4@X]`. Our\n",
    "sharding rule for addition would throw an error: the result would need to be\n",
    "`f32[4@X, 4@X]`, which tries uses a mesh axis twice, which is illegal. But say you\n",
    "want to perform the operation anyway, and you want the result to be sharded along\n",
    "the first axis only, like `f32[4@X, 4]`. You can do this as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fpFEaMBcXsJG",
    "outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ERROR!\n",
      "add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]\n",
      "=== try again with auto_axes ===\n",
      "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
      "Result type: ShapedArray(int32[4@X,4])\n"
     ]
    }
   ],
   "source": [
    "from jax.sharding import auto_axes, explicit_axes\n",
    "\n",
    "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", None))\n",
    "some_y = reshard(np.arange(16).reshape(4, 4), P(None, \"X\"))\n",
    "\n",
    "try:\n",
    "  some_x + some_y\n",
    "except Exception as e:\n",
    "  print(\"ERROR!\")\n",
    "  print(e)\n",
    "\n",
    "print(\"=== try again with auto_axes ===\")\n",
    "\n",
    "@auto_axes\n",
    "def add_with_out_sharding_kwarg(x, y):\n",
    "  print(f\"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}\")\n",
    "  return x + y\n",
    "\n",
    "result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P(\"X\", None))\n",
    "print(f\"Result type: {jax.typeof(result)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8-_zDr-AXvb6"
   },
   "source": [
    "## Using a mixture of sharding modes\n",
    "\n",
    "JAX now has three styles of parallelism:\n",
    "\n",
    " * *Automatic sharding* is where you treat all the devices as a single logical\n",
    "   machine and write a \"global view\" array program for that machine. The\n",
    "   compiler decides how to partition the data and computation across the\n",
    "   available devices. You can give hints to the compiler using\n",
    "   `with_sharding_constraint`.\n",
    " * *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n",
    "   you're writing a global-view program. The difference is that the sharding\n",
    "   of each array is part of the array's JAX-level type making it an explicit\n",
    "   part of the programming model. These shardings are propagated at the JAX\n",
    "   level and queryable at trace time. It's still the compiler's responsibility\n",
    "   to turn the whole-array program into per-device programs (turning `jnp.sum`\n",
    "   into `psum` for example) but the compiler is heavily constrained by the\n",
    "   user-supplied shardings.\n",
    " * *Manual Sharding* (`shard_map`) is where you write a program from the\n",
    "   perspective of a single device. Communication between devices happens via\n",
    "   explicit collective operations like psum.\n",
    "\n",
    "A summary table:\n",
    "\n",
    "| Mode | View? | Explicit sharding? | Explicit Collectives? |\n",
    "|---|---|---|---|\n",
    "| Auto | Global | ❌ | ❌ |\n",
    "| Explicit | Global | ✅ | ❌ |\n",
    "| Manual | Per-device | ✅ | ✅ |\n",
    "\n",
    "The current mesh tells us which sharding mode we're in. We can query it with\n",
    "`get_abstract_mesh`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "geptWrdYX0OM",
    "outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n"
     ]
    }
   ],
   "source": [
    "print(f\"Current mesh is: {get_abstract_mesh()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AQQjzUeGX4P6"
   },
   "source": [
    "Since `axis_types=(Explicit, Explicit)`, this means we're in fully-explicit\n",
    "mode. Notice that the sharding mode is associated with a mesh _axis_, not the\n",
    "mesh as a whole. We can actually mix sharding modes by having a different\n",
    "sharding mode for each mesh axis. Shardings (on JAX-level types) can only\n",
    "mention _explicit_ mesh axes and collective operations like `psum` can only\n",
    "mention _manual_ mesh axes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LZWjgiMZ7uSS"
   },
   "source": [
    "You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "IVzPSkp77uCF",
    "outputId": "db80a604-98ac-4343-8677-23729adf7ffc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n",
      "x.sharding: ShapedArray(float32[4@X,4@Y])\n",
      "\n",
      "mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n",
      "y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n",
      "\n",
      "z.sharding: ShapedArray(float32[4@X,4@Y])\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Array([[ 1.        ,  2.682942  ,  2.818595  ,  1.28224   ],\n",
       "       [-0.513605  , -0.9178486 ,  0.44116902,  2.3139732 ],\n",
       "       [ 2.9787164 ,  1.824237  , -0.08804226, -0.99998045],\n",
       "       [-0.07314587,  1.840334  ,  2.9812148 ,  2.3005757 ]],      dtype=float32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import functools\n",
    "\n",
    "@functools.partial(auto_axes, axes='X')\n",
    "def g(y):\n",
    "  print(f'mesh inside g: {get_abstract_mesh()}')\n",
    "  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n",
    "  return y * 2\n",
    "\n",
    "@jax.jit\n",
    "def f(arr1):\n",
    "  print(f'mesh inside f: {get_abstract_mesh()}')\n",
    "  x = jnp.sin(arr1)\n",
    "  print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n",
    "\n",
    "  z = g(x, out_sharding=P(\"X\", \"Y\"))\n",
    "\n",
    "  print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n",
    "  return z + 1\n",
    "\n",
    "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
    "f(some_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_3sfJjRq8w9f"
   },
   "source": [
    "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.\n",
    "\n",
    "\n",
    "You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a102e9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "auto_mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n",
    "                           axis_types=(AxisType.Auto, AxisType.Auto))\n",
    "\n",
    "@functools.partial(explicit_axes, axes=('X', 'Y'))\n",
    "def explicit_g(y):\n",
    "  print(f'mesh inside g: {get_abstract_mesh()}')\n",
    "  print(f'y.sharding inside g: {jax.typeof(y) = }')\n",
    "  z = y * 2\n",
    "  print(f'z.sharding inside g: {jax.typeof(z) = }', end='\\n\\n')\n",
    "  return z\n",
    "\n",
    "@jax.jit\n",
    "def f(arr1):\n",
    "  print(f'mesh inside f: {get_abstract_mesh()}', end='\\n\\n')\n",
    "  x = jnp.sin(arr1)\n",
    "\n",
    "  z = explicit_g(x, in_sharding=P(\"X\", \"Y\"))\n",
    "\n",
    "  return z + 1\n",
    "\n",
    "with jax.set_mesh(auto_mesh):\n",
    "  some_x = jax.device_put(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
    "  f(some_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e64d40de",
   "metadata": {},
   "source": [
    "As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`.\n",
    "Because of that, sharding is visible on the type of arrays inside `g`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sJcWbfAh7UcO"
   },
   "source": [
    "## Concrete array shardings can mention `Auto` mesh axis\n",
    "\n",
    "You can query the sharding of a concrete array `x` with `x.sharding`. You\n",
    "might expect the result to be the same as the sharding associated with the\n",
    "value's type, `jax.typeof(x).sharding`. It might not be! The concrete array sharding, `x.sharding`, describes the sharding along\n",
    "both `Explicit` and `Auto` mesh axes. It's the sharding that the compiler\n",
    "eventually chose. Whereas the type-specificed sharding,\n",
    "`jax.typeof(x).sharding`, only describes the sharding along `Explicit` mesh\n",
    "axes. The `Auto` axes are deliberately hidden from the type because they're\n",
    "the purview of the compiler. We can think of the concrete array sharding being consistent with, but more specific than,\n",
    "the type-specified sharding. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ivLl6bxmX7EZ",
    "outputId": "6d7b7fce-68b6-47f1-b214-d62bda8d7b6e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit)) ===\n",
      "Concrete value sharding: PartitionSpec('X',)\n",
      "Type-specified sharding: PartitionSpec('X',)\n",
      "=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto)) ===\n",
      "Concrete value sharding: PartitionSpec('X',)\n",
      "Type-specified sharding: PartitionSpec(None,)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Array([ 0.        ,  0.84147096,  0.9092974 ,  0.14112   , -0.7568025 ,\n",
       "       -0.9589243 , -0.2794155 ,  0.6569866 ], dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def compare_shardings(x):\n",
    "  print(f\"=== with mesh: {get_abstract_mesh()} ===\")\n",
    "  print(f\"Concrete value sharding: {x.sharding.spec}\")\n",
    "  print(f\"Type-specified sharding: {jax.typeof(x).sharding.spec}\")\n",
    "\n",
    "my_array = jnp.sin(reshard(np.arange(8), P(\"X\")))\n",
    "compare_shardings(my_array)\n",
    "\n",
    "@auto_axes\n",
    "def check_in_auto_context(x):\n",
    "  compare_shardings(x)\n",
    "  return x\n",
    "\n",
    "check_in_auto_context(my_array, out_sharding=P(\"X\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MRFccsi5X8so"
   },
   "source": [
    "Notice that at the top level, where we're currently in a fully `Explicit` mesh\n",
    "context, the concrete array sharding and type-specified sharding agree. But\n",
    "under the `auto_axes` decorator we're in a fully `Auto` mesh context and the\n",
    "two shardings disagree: the type-specified sharding is `P(None)` whereas the\n",
    "concrete array sharding is `P(\"X\")` (though it could be anything! It's up to\n",
    "the compiler)."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
